function GapfilledDataSet = PredictMissingMarkers_ForSubmission(Data_gaps,varargin)
% PredictMissingMarkers_ForSubmission
% Searches for gaps (NaNs) in the marker coordinates and fills these gaps
% using the correlation structure in the marker coordinates (PCA-based
% reconstruction).
%
% Input:
% Required:
% Data_gaps: matrix with marker data organized in the form
%            [ x1(t1), y1(t1), z1, x2, y2, z2, x3, ..., zn(t1)
%              x1(t2), y1(t2), ...
%               ...                                     ...
%              x1(tm), y1(tm), ...                    , zn(tm)]
%            Gaps must be marked with NaN. A marker that has gaps must
%            have NaNs in all three of its coordinate columns, and at
%            least one marker must be completely free of gaps (it is
%            needed to compute the mean trajectory used for centering).
%
% Optional parameter - value pairs:
% 'MMRecAlgorithm': Reconstruction strategy for gaps in multiple markers
%                   (integer from 1-4, representing R1-R4. Default: 4)
% 'weightexponent': Exponent for calculation of weight-factors.
%                   (Default: -3)
% 'MMweight':       Weight of missing marker relative to minimum weight
%                   (Default: 1e-3)
% 'R4threshold':    Cutoff distance for distal markers in R4, relative to
%                   average Euclidean distance between all markers. (Default: 0.5)
% 'MinCumSV':       Cumulative sum of normalized singular values that determines the
%                   number of PC-vectors included in the analysis (Default: 0.99)
%
% Output:
% GapfilledDataSet: Matrix in the same form as the input matrix
%                   where the missing data frames are replaced by
%                   reconstructed marker trajectories.
%
% Errors (only raised when the data actually contains gaps):
% PredictMissingMarkers:invalidColumnCount  - number of columns is not a
%                                             multiple of 3
% PredictMissingMarkers:partialMarkerGap    - a marker has NaNs in only
%                                             some of its x/y/z columns
% PredictMissingMarkers:noCompleteMarkers   - every marker contains gaps
% PredictMissingMarkers:invalidAlgorithm    - unknown reconstruction strategy


%% Set program parameters:

%default values
parser = inputParser;
defaultMultipleMissingMarkerRecAlgorithm = 4;
defaultSpatialWeightExponent = -3;
defaultSpatialMMweight = 1e-3;
defaultR4threshold = 0.5;
defaultMinCumSV = 0.99;
expectedMultipleMarkerRecAlgorithm = [1 2 3 4];

checkRecAlg = @(x) isnumeric(x) && ismember(x,expectedMultipleMarkerRecAlgorithm);
checkCumVar = @(x) isnumeric(x) && x>0 && x<1;

% NOTE(review): the options are registered with addOptional, i.e. in the
% positional order weightexponent, MMweight, R4threshold, MinCumSV,
% MMRecAlgorithm. Confirm that the name-value usage documented above is
% parsed as intended on the MATLAB release in use.
addRequired(parser,'Data_gaps',@isnumeric);
addOptional(parser,'weightexponent',defaultSpatialWeightExponent,@isnumeric);
addOptional(parser,'MMweight',defaultSpatialMMweight,@isnumeric);
addOptional(parser,'R4threshold',defaultR4threshold,@isnumeric);
addOptional(parser,'MinCumSV',defaultMinCumSV,checkCumVar);
addOptional(parser,'MMRecAlgorithm',defaultMultipleMissingMarkerRecAlgorithm,...
    checkRecAlg);


%set parameters:
parse(parser,Data_gaps,varargin{:});

% These four variables are also read by the nested function "reconstruct".
weightexponent = parser.Results.weightexponent;
MMweight = parser.Results.MMweight;
R4threshold = parser.Results.R4threshold;
MinCumSV = parser.Results.MinCumSV;


%% Detect gaps and validate the layout of the data

columns = size(Data_gaps,2);
% Detect which columns have gaps and where
columns_with_gaps = find(any(isnan(Data_gaps),1));
frames_with_gaps = find(any(isnan(Data_gaps),2));

if isempty(frames_with_gaps)
   warning('It does not appear to be any gaps in the submitted data. Make sure that gaps are represented by nan-s')
   GapfilledDataSet = Data_gaps;
   return
end

if mod(columns,3) ~= 0
    error('PredictMissingMarkers:invalidColumnCount',...
        'Data_gaps must have three columns (x, y, z) per marker, but it has %d columns.',...
        columns);
end

% One column per marker, rows = x/y/z: true where that coordinate has gaps.
gapsPerMarker = reshape(any(isnan(Data_gaps),1),3,[]);
if any(any(gapsPerMarker,1) & ~all(gapsPerMarker,1))
    % The marker indices below are derived from every third gapped column,
    % and the x/y/z strides of the gap-free data rely on whole markers
    % being removed, so partially gapped markers cannot be handled.
    error('PredictMissingMarkers:partialMarkerGap',...
        'Markers with gaps must contain NaNs in all three coordinate columns (x, y and z).');
end
if all(gapsPerMarker(:))
    % Without a gap-free marker the mean trajectory would be NaN, which
    % would turn the complete data set into NaNs.
    error('PredictMissingMarkers:noCompleteMarkers',...
        'At least one marker without gaps is required to compute the mean trajectory.');
end

% Indices of the markers that contain gaps (row vector, used as loop range)
MarkersWithGaps = columns_with_gaps(3:3:end)./3;

%% center the data by subtracting a mean trajectory,
% The mean trajectory is computed from the gap-free markers only.

Data_without_compromized_markers = Data_gaps;
Data_without_compromized_markers(:,columns_with_gaps)=[];

mean_trajectory.x = mean(Data_without_compromized_markers(:,1:3:end),2);
mean_trajectory.y = mean(Data_without_compromized_markers(:,2:3:end),2);
mean_trajectory.z = mean(Data_without_compromized_markers(:,3:3:end),2);

Data_gaps(:,1:3:columns) = ...
    Data_gaps(:,1:3:columns)-...
        repmat(mean_trajectory.x,1,columns/3);
Data_gaps(:,2:3:columns) = ...
    Data_gaps(:,2:3:columns)-...
        repmat(mean_trajectory.y,1,columns/3);
Data_gaps(:,3:3:columns) = ...
    Data_gaps(:,3:3:columns)-...
        repmat(mean_trajectory.z,1,columns/3);

%% Allocate space for output matrix:
GapfilledDataSet = Data_gaps;

%% Choose reconstruction strategy
switch parser.Results.MMRecAlgorithm
    case 1
        % R1 strategy: reconstruct all gaps at once from the full data set
        ReconstructedFullDataSet = reconstruct(Data_gaps);
        indexes_with_gaps = find(isnan(Data_gaps));
        GapfilledDataSet(indexes_with_gaps) = ReconstructedFullDataSet(indexes_with_gaps);

    case 2
        % R2 strategy: one marker at a time; markers whose gaps overlap
        % with the gaps of marker i are zeroed.
        for i=MarkersWithGaps
            Data_gaps_removedFrames = Data_gaps;
            %Find overlapping gaps (if they exist):
            CumulativeSum = cumsum(isnan(GapfilledDataSet(:,1:3:end)),2);
            CumulativeSum = CumulativeSum.*repmat(isnan(GapfilledDataSet(:,3*i-2)),...
                1,size(CumulativeSum,2));
            %now the last columns in "CumulativeSum" is zero if there are
            %no overlapping gaps in other markers. Otherwise, it is the
            %number of markers that overlap (included marker i)
            if max(CumulativeSum(:,end))>1 %Remove markers with overlapping gaps
                rowsWithOverlap = find(CumulativeSum(:,end)>0);
                % differencing the cumulative sum recovers the per-marker
                % gap indicator for the selected rows
                temp = diff([zeros(length(rowsWithOverlap),1) CumulativeSum(rowsWithOverlap,:)],1,2);
                OverlappingMarkers = find(sum(temp,1)>0);
                OverlappingMarkers(OverlappingMarkers==i) = []; %remove marker i
                OverlappingMarkers = repmat(OverlappingMarkers,3,1);
                OverlappingCols = reshape(3*OverlappingMarkers-repmat([0 1 2]',1,...
                    size(OverlappingMarkers,2)),1,3*size(OverlappingMarkers,2));
                %replace columns with overlapping gaps by zeros:
                Data_gaps_removedFrames(:,OverlappingCols) = 0;
            end
            %remove frames with incomplete information:
            Data_gaps_removedFrames(frames_with_gaps,:) = [];
            %add frames with gaps in marker i to the end of the matrix:
            frames2rec = find(any(isnan(Data_gaps(:,3*i-2:3*i)),2));
            Data_gaps_removedFrames = cat(1,Data_gaps_removedFrames,...
                Data_gaps(frames2rec,:));
            %reconstruct frame
            TempReconstructedData = reconstruct(Data_gaps_removedFrames);

            % the frames to reconstruct are the last rows of the matrix
            GapfilledDataSet(frames2rec,3*i-2:3*i) = ...
                TempReconstructedData(end-length(frames2rec)+1:end,3*i-2:3*i);
        end %end filling marker i

    case 3
        % R3 strategy: one marker at a time; all other markers with gaps
        % are zeroed.
        % Preallocate to the full size so that the linear indices computed
        % from Data_gaps below are always valid for this matrix.
        ReconstructedFullDataSet = Data_gaps;
        for i=MarkersWithGaps
            Data_gaps_removedCols = Data_gaps;
            Data_gaps_removedCols(:,columns_with_gaps) = 0;
            Data_gaps_removedCols(:,3*i-2:3*i) = Data_gaps(:,3*i-2:3*i);
            TempReconstructedData = reconstruct(Data_gaps_removedCols);
            ReconstructedFullDataSet(:,3*i-2:3*i) = TempReconstructedData(:,3*i-2:3*i);
        end
        GapfilledDataSet = Data_gaps;
        indexes_with_gaps = find(isnan(Data_gaps));
        GapfilledDataSet(indexes_with_gaps) = ReconstructedFullDataSet(indexes_with_gaps);

    case 4
        % R4 strategy: one marker at a time; distal markers with gaps are
        % zeroed, overlapping gaps in the remaining markers are
        % reconstructed separately for each combination of missing markers.
        for i=MarkersWithGaps
            % remove columns distal to marker i
            EuclDist2Markers = distance2marker(Data_gaps,3*i-2:3*i);
            thresh =  R4threshold.*mean(EuclDist2Markers);
            %find markers with gaps and Eucl. distance to marker i greater
            %than threshold
            Cols2Zero = find(reshape(repmat(EuclDist2Markers,3,1),1,columns)>thresh ...
                & any(isnan(Data_gaps),1));
            Data_gaps_removedCols = Data_gaps;
            Data_gaps_removedCols(:,Cols2Zero) = 0;
            Data_gaps_removedCols(:,3*i-2:3*i) = Data_gaps(:,3*i-2:3*i);
            %frames without any gap after zeroing the distal markers:
            CompleteFrames = find(~any(isnan(Data_gaps_removedCols),2));
            %Find overlapping gaps in marker i (if they exist):
            CumulativeSum = cumsum(isnan(Data_gaps_removedCols(:,1:3:end)),2);
            CumulativeSum = CumulativeSum.*repmat(isnan(Data_gaps_removedCols(:,3*i-2)),...
                1,size(CumulativeSum,2));
            %now the last columns in "CumulativeSum" is zero if there are
            %no overlapping gaps in other markers. Otherwise, it is the
            %number of markers that overlap (included marker i)
            if max(CumulativeSum(:,end))>1 %Reconstruct gaps that overlap
                % Frames with overlap that still have to be processed. A
                % mask is used (instead of stepping through the frame list
                % by the number of frames just filled) because frames with
                % the same combination of missing markers do not have to
                % be contiguous.
                remainingOverlap = CumulativeSum(:,end)>1;
                while any(remainingOverlap)
                    firstFrame = find(remainingOverlap,1,'first');
                    % all frames with this specific marker combination
                    % (a row of CumulativeSum identifies the combination)
                    gappedFrames = find(ismember(CumulativeSum,...
                        CumulativeSum(firstFrame,:),'rows'));
                    %use the complete frames, then add the gap we want
                    %to reconstruct
                    CompleteAndGappedFrames = ([CompleteFrames' gappedFrames']);
                    %reconstruct overlapping gap:
                    reconstructedData = reconstruct(Data_gaps_removedCols(CompleteAndGappedFrames,:));
                    fillframes = length(CompleteFrames)+1:length(CompleteAndGappedFrames);
                    %fill gaps in marker i
                    GapfilledDataSet(gappedFrames,3*i-2:3*i) = ...
                            reconstructedData(fillframes,3*i-2:3*i);
                    %mark these frames as done (gappedFrames always
                    %contains firstFrame, so the loop terminates)
                    remainingOverlap(gappedFrames) = false;
                end
            end % end filling overlapping gaps

            %fill the non-overlapping gaps:
            gappedFrames = find(any(isnan(GapfilledDataSet(:,i*3)),2));
            if ~isempty(gappedFrames) %some gaps might have been completely filled
                %by filling the overlapping gap. Skip these
                CompleteAndGappedFrames = ([CompleteFrames' gappedFrames']);
                fillframes = length(CompleteFrames)+1:length(CompleteAndGappedFrames);
                reconstructedData = reconstruct(Data_gaps_removedCols(CompleteAndGappedFrames,:));
                GapfilledDataSet(gappedFrames,3*i-2:3*i) = ...
                                reconstructedData(fillframes,3*i-2:3*i);
            end
        end %end filling marker i

    otherwise
        % Not reachable as long as the input parser validates
        % 'MMRecAlgorithm'; fail loudly instead of continuing with an
        % output that cannot be indexed below.
        error('PredictMissingMarkers:invalidAlgorithm',...
            'Error: invalid reconstruction approach');
end

%% Add mean trajectory

GapfilledDataSet(:,1:3:columns) = ...
    GapfilledDataSet(:,1:3:columns) + ...
    repmat(mean_trajectory.x,1,columns/3);
GapfilledDataSet(:,2:3:columns) = ...
    GapfilledDataSet(:,2:3:columns) + ...
    repmat(mean_trajectory.y,1,columns/3);
GapfilledDataSet(:,3:3:columns) = ...
    GapfilledDataSet(:,3:3:columns) + ...
    repmat(mean_trajectory.z,1,columns/3);


%% reconstruction function
function Reconstruction = reconstruct(Data2reconstruct)
    % Reconstructs the columns of Data2reconstruct that contain NaNs.
    % Nested function: reads Data_gaps (centered full data set),
    % weightexponent, MMweight and MinCumSV from the enclosing function.
    %
    % Input:  Data2reconstruct - (sub)set of frames of the centered data,
    %                            same column layout as Data_gaps; columns
    %                            may have been zeroed by the caller
    % Output: Reconstruction   - copy of Data2reconstruct in which every
    %                            column containing NaNs is replaced (in all
    %                            rows) by its reconstruction

    %Determine columns with gaps
    ColsWithGaps = find(any(isnan(Data2reconstruct),1));
    %find the weight vector based on Euclidean distances between markers
    weightvector = distance2marker(Data_gaps,ColsWithGaps);
    % for "all at once"-approach: keep only the smallest distances for weights
    if size(weightvector,1)>1
        weightvector = min(weightvector);
    end
    weightvector = weightvector.^weightexponent;
    weightvector(ColsWithGaps(3:3:end)/3) = ...
        MMweight.*min(weightvector(isfinite(weightvector)));

    %Avoid division by zero: This can happen because some columns are zeroed, giving them infinite
    %weights. To avoid problems during normalization, give infinite weights
    %a finite value (the actual value does not matter, since the columns are zeroed)
    weightvector(~isfinite(weightvector)) = 1;
    %one weight per column (each marker weight repeated for x, y and z)
    columnWeights = reshape([1 1 1]'*weightvector,1,[]);

    %define matrices needed for reconstruction:
    M_zeros = Data2reconstruct;
    M_zeros(:,ColsWithGaps)=0;

    %frames without any gap:
    N_no_gaps = Data2reconstruct(~any(isnan(Data2reconstruct),2),:);

    N_zeros = N_no_gaps;
    N_zeros(:,ColsWithGaps) = 0;

    nRowsM = size(M_zeros,1);
    nRowsN = size(N_no_gaps,1);

    %normalize to unit variance, then add weighting vector:
    mean_N_no_gaps    = mean(N_no_gaps,1);
    mean_N_zeros = mean(N_zeros,1);
    stdev_N_no_gaps   = std(N_no_gaps,1,1);
    %the zeroed columns should remain zero. Set the stdev_N_no_gaps to
    % one for the zeroed (constant) columns, so that they remain zero after
    % normalization
    stdev_N_no_gaps(stdev_N_no_gaps==0) = 1;

    M_zeros = (M_zeros - repmat(mean_N_zeros,nRowsM,1))./...
                           repmat(stdev_N_no_gaps,nRowsM,1).*...
                           repmat(columnWeights,nRowsM,1);

    N_no_gaps =(N_no_gaps-repmat(mean_N_no_gaps,nRowsN,1))./...
                            repmat(stdev_N_no_gaps,nRowsN,1).*...
                            repmat(columnWeights,nRowsN,1);

    N_zeros = (N_zeros - repmat(mean_N_zeros,nRowsN,1))./...
                           repmat(stdev_N_no_gaps,nRowsN,1).*...
                           repmat(columnWeights,nRowsN,1);


    [PC_vectors__no_gaps,sqrtEigvals_no_gaps] = PCA(N_no_gaps);

    [PC_vectors__zeros,sqrtEigvals_zeros] = PCA(N_zeros);

    % Select the number of PC-vectors to include in the analysis: the
    % larger of the two counts needed to exceed MinCumSV of the summed
    % singular values
    n_eig = [find(cumsum(sqrtEigvals_no_gaps)>...
        MinCumSV*sum(sqrtEigvals_no_gaps),1,'first');
        find(cumsum(sqrtEigvals_zeros)>...
        MinCumSV*sum(sqrtEigvals_zeros),1,'first')];
    n_eig = max(n_eig);


    PC_vectors__no_gaps = PC_vectors__no_gaps(:,1:n_eig);
    PC_vectors__zeros = PC_vectors__zeros(:,1:n_eig);
    % Calculate Transformation Matrix for Principal Movements
    T = PC_vectors__no_gaps'*PC_vectors__zeros;
    % Transform Data first into incomplete-, then into full-PC basis system.
    ReconstructedData = M_zeros*PC_vectors__zeros*T*PC_vectors__no_gaps';
    % (Equation 1 in Federolfs 2013 paper)

    % Reverse normalization
    ReconstructedData = repmat(mean_N_no_gaps,nRowsM,1)...
        +ReconstructedData.*repmat(stdev_N_no_gaps,nRowsM,1)./...
        repmat(columnWeights,nRowsM,1);
    %prepare output: only the columns that contain gaps in the submitted
    %data are replaced by the reconstruction
    Reconstruction = Data2reconstruct;
    Reconstruction(:,ColsWithGaps) = ReconstructedData(:,ColsWithGaps);

end %end function reconstruct


end %end PredictMissingMarkers

function [PC,sqrtEV] = PCA(Data)
% PCA  Principal component vectors of Data via an economy-size SVD.
%
% Input:
%   Data   - matrix with one observation per row. The function does not
%            center the data itself.
% Output:
%   PC     - right singular vectors of Data/sqrt(N-1), one per column,
%            where N is the number of rows of Data
%   sqrtEV - column vector with the corresponding singular values

nObservations = size(Data,1);
scaledData = Data ./ sqrt(nObservations-1);
% the left singular vectors are not needed
[~,singularValues,PC] = svd(scaledData,'econ');
sqrtEV = diag(singularValues);

end

%spatial distance function
    function [distArray] = distance2marker(MarkerData,colWithGaps)
        % distance2marker  Mean Euclidean distance from selected markers to
        % all markers.
        %
        % Input:
        %   MarkerData  - frames x (3*markers) matrix with columns ordered
        %                 x1, y1, z1, x2, ...
        %   colWithGaps - column indices of the selected markers; every
        %                 third entry is taken as the z-column of one marker
        % Output:
        %   distArray   - (length(colWithGaps)/3) x markers matrix with the
        %                 distances averaged over all frames (NaN values
        %                 are skipped by nanmean)
        [nFrames,nCols] = size(MarkerData);
        nMarkers = nCols/3;
        gapMarkerIdx = colWithGaps(3:3:end)/3;
        nGapMarkers = length(colWithGaps)/3;
        % rearrange to coordinate x marker x frame
        xyzByFrame = reshape(MarkerData',3,nMarkers,nFrames);

        distArray = nan(nGapMarkers,nMarkers,nFrames);
        for frame = 1:nFrames
            % markers x 3 positions of the current frame
            framePositions = xyzByFrame(:,:,frame)';
            distArray(:,:,frame) = pdist2(framePositions(gapMarkerIdx,:),...
                framePositions,'euclidean');
        end
        distArray = nanmean(distArray,3);
    end









